{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Shakespeare Language Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.utils.rnn as rnn\n",
    "from torch.utils.data import Dataset, DataLoader, TensorDataset\n",
    "import numpy as np\n",
    "\n",
    "import shakespeare_data as sh\n",
    "\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "DEVICE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fixed length input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1609\n",
      " THE SONNETS\n",
      " by William Shakespeare\n",
      "                      1\n",
      "   From fairest creatures we desire increase,\n",
      "   That thereby beauty's rose might never die,\n",
      "   But as the riper should by time decease,\n",
      "...,\n",
      "   And new pervert a reconciled maid.'\n",
      " THE END\n",
      "\n",
      "Total character count: 5551930\n",
      "Unique character count: 84\n",
      "\n",
      "(5551930,)\n",
      "[12 17 11 20  0  1 45 33 30  1 44 40 39 39 30 45 44]\n",
      "1609\n",
      " THE SONNETS\n"
     ]
    }
   ],
   "source": [
    "# Data - refer to shakespeare_data.py for details\n",
    "corpus = sh.read_corpus()\n",
    "print(\"{}...{}\".format(corpus[:203], corpus[-50:]))\n",
    "print(\"Total character count: {}\".format(len(corpus)))\n",
    "chars, charmap = sh.get_charmap(corpus)\n",
    "charcount = len(chars)\n",
    "print(\"Unique character count: {}\\n\".format(len(chars)))\n",
    "shakespeare_array = sh.map_corpus(corpus, charmap)\n",
    "print(shakespeare_array.shape)\n",
    "print(shakespeare_array[:17])\n",
    "print(sh.to_text(shakespeare_array[:17],chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset class. Transforme raw text into a set of sequences of fixed length, and extracts inputs and targets\n",
    "class TextDataset(Dataset):\n",
    "    def __init__(self,text, seq_len = 200):\n",
    "        n_seq = len(text) // seq_len\n",
    "        text = text[:n_seq * seq_len]\n",
    "        self.data = torch.tensor(text).view(-1,seq_len)\n",
    "    def __getitem__(self,i):\n",
    "        txt = self.data[i]\n",
    "        return txt[:-1],txt[1:]\n",
    "    def __len__(self):\n",
    "        return self.data.size(0)\n",
    "\n",
    "# Collate function. Transform a list of sequences into a batch. Passed as an argument to the DataLoader.\n",
    "# Returns data on the format seq_len x batch_size\n",
    "def collate(seq_list):\n",
    "    inputs = torch.cat([s[0].unsqueeze(1) for s in seq_list],dim=1)\n",
    "    targets = torch.cat([s[1].unsqueeze(1) for s in seq_list],dim=1)\n",
    "    return inputs,targets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model\n",
    "class CharLanguageModel(nn.Module):\n",
    "\n",
    "    def __init__(self,vocab_size,embed_size,hidden_size, nlayers):\n",
    "        super(CharLanguageModel,self).__init__()\n",
    "        self.vocab_size=vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.nlayers=nlayers\n",
    "        self.embedding = nn.Embedding(vocab_size,embed_size) # Embedding layer\n",
    "        self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # Recurrent network\n",
    "        self.scoring = nn.Linear(hidden_size,vocab_size) # Projection layer\n",
    "        \n",
    "    def forward(self,seq_batch): #L x N\n",
    "        # returns 3D logits\n",
    "        batch_size = seq_batch.size(1)\n",
    "        embed = self.embedding(seq_batch) #L x N x E\n",
    "        hidden = None\n",
    "        output_lstm,hidden = self.rnn(embed,hidden) #L x N x H\n",
    "        output_lstm_flatten = output_lstm.view(-1,self.hidden_size) #(L*N) x H\n",
    "        output_flatten = self.scoring(output_lstm_flatten) #(L*N) x V\n",
    "        return output_flatten.view(-1,batch_size,self.vocab_size)\n",
    "    \n",
    "    def generate(self,seq, n_words): # L x V\n",
    "        # performs greedy search to extract and return words (one sequence).\n",
    "        generated_words = []\n",
    "        embed = self.embedding(seq).unsqueeze(1) # L x 1 x E\n",
    "        hidden = None\n",
    "        output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H\n",
    "        output = output_lstm[-1] # 1 x H\n",
    "        scores = self.scoring(output) # 1 x V\n",
    "        _,current_word = torch.max(scores,dim=1) # 1 x 1\n",
    "        generated_words.append(current_word)\n",
    "        if n_words > 1:\n",
    "            for i in range(n_words-1):\n",
    "                embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E\n",
    "                output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H\n",
    "                output = output_lstm[0] # 1 x H\n",
    "                scores = self.scoring(output) # V\n",
    "                _,current_word = torch.max(scores,dim=1) # 1\n",
    "                generated_words.append(current_word)\n",
    "        return torch.cat(generated_words,dim=0)\n",
    "        \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, optimizer, train_loader, val_loader):\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    criterion = criterion.to(DEVICE)\n",
    "    batch_id=0\n",
    "    for inputs,targets in train_loader:\n",
    "        batch_id+=1\n",
    "        inputs = inputs.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "        outputs = model(inputs) # 3D\n",
    "        loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1)) # Loss of the flattened outputs\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_id % 100 == 0:\n",
    "            lpw = loss.item()\n",
    "            print(\"At batch\",batch_id)\n",
    "            print(\"Training loss per word:\",lpw)\n",
    "            print(\"Training perplexity :\",np.exp(lpw))\n",
    "    \n",
    "    val_loss = 0\n",
    "    batch_id=0\n",
    "    for inputs,targets in val_loader:\n",
    "        batch_id+=1\n",
    "        inputs = inputs.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs.view(-1,outputs.size(2)),targets.view(-1))\n",
    "        val_loss+=loss.item()\n",
    "    val_lpw = val_loss / batch_id\n",
    "    print(\"\\nValidation loss per word:\",val_lpw)\n",
    "    print(\"Validation perplexity :\",np.exp(val_lpw),\"\\n\")\n",
    "    return val_lpw\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CharLanguageModel(charcount,256,256,3)\n",
    "model = model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)\n",
    "split = 5000000\n",
    "train_dataset = TextDataset(shakespeare_array[:split])\n",
    "val_dataset = TextDataset(shakespeare_array[split:])\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate)\n",
    "val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "At batch 100\n",
      "Training loss per word: 2.794588565826416\n",
      "Training perplexity : 16.3558979932\n",
      "At batch 200\n",
      "Training loss per word: 2.088837146759033\n",
      "Training perplexity : 8.07551905871\n",
      "At batch 300\n",
      "Training loss per word: 1.8746708631515503\n",
      "Training perplexity : 6.51867323164\n",
      "\n",
      "Validation loss per word: 1.7766425360080809\n",
      "Validation perplexity : 5.90998052416 \n",
      "\n",
      "At batch 100\n",
      "Training loss per word: 1.6976581811904907\n",
      "Training perplexity : 5.46114339693\n",
      "At batch 200\n",
      "Training loss per word: 1.6234902143478394\n",
      "Training perplexity : 5.07075749808\n",
      "At batch 300\n",
      "Training loss per word: 1.5321810245513916\n",
      "Training perplexity : 4.62826017135\n",
      "\n",
      "Validation loss per word: 1.5648872769156168\n",
      "Validation perplexity : 4.78213584841 \n",
      "\n",
      "At batch 100\n",
      "Training loss per word: 1.4516830444335938\n",
      "Training perplexity : 4.27029556764\n",
      "At batch 200\n",
      "Training loss per word: 1.4193761348724365\n",
      "Training perplexity : 4.13454024001\n",
      "At batch 300\n",
      "Training loss per word: 1.3626270294189453\n",
      "Training perplexity : 3.90644217237\n",
      "\n",
      "Validation loss per word: 1.4870204398798388\n",
      "Validation perplexity : 4.42389460237 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    train_epoch(model, optimizer, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(model, seed,nwords):\n",
    "    seq = sh.map_corpus(seed, charmap)\n",
    "    seq = torch.tensor(seq).to(DEVICE)\n",
    "    out = model.generate(seq,nwords)\n",
    "    return sh.to_text(out.cpu().detach().numpy(),chars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uiet of \n"
     ]
    }
   ],
   "source": [
    "print(generate(model, \"To be, or not to be, that is the q\",8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the state the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake the stake \n"
     ]
    }
   ],
   "source": [
    "print(generate(model, \"Richard \", 1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Packed sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1609\n",
      "\n",
      " THE SONNETS\n",
      "\n",
      " by William Shakespeare\n",
      "\n",
      "                      1\n",
      "\n",
      "   From fairest creatures we desire increase,\n",
      "\n",
      "   That thereby beauty's rose might never die,\n",
      "\n",
      "   But as the riper should by time decease,\n",
      "\n",
      "   His tender heir might bear his memory:\n",
      "\n",
      "   But thou contracted to thine own bright eyes,\n",
      "\n",
      "   Feed'st thy light's flame with self-substantial fuel,\n",
      "\n",
      "114638\n"
     ]
    }
   ],
   "source": [
    "stop_character = charmap['\\n']\n",
    "space_character = charmap[\" \"]\n",
    "lines = np.split(shakespeare_array, np.where(shakespeare_array == stop_character)[0]+1) # split the data in lines\n",
    "shakespeare_lines = []\n",
    "for s in lines:\n",
    "    s_trimmed = np.trim_zeros(s-space_character)+space_character # remove space-only lines\n",
    "    if len(s_trimmed)>1:\n",
    "        shakespeare_lines.append(s)\n",
    "for i in range(10):\n",
    "    print(sh.to_text(shakespeare_lines[i],chars))\n",
    "print(len(shakespeare_lines))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinesDataset(Dataset):\n",
    "    def __init__(self,lines):\n",
    "        self.lines=[torch.tensor(l) for l in lines]\n",
    "    def __getitem__(self,i):\n",
    "        line = self.lines[i]\n",
    "        return line[:-1].to(DEVICE),line[1:].to(DEVICE)\n",
    "    def __len__(self):\n",
    "        return len(self.lines)\n",
    "\n",
    "def collate_lines(seq_list):\n",
    "    inputs,targets = zip(*seq_list)\n",
    "    lens = [len(seq) for seq in inputs]\n",
    "    seq_order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)\n",
    "    inputs = [inputs[i] for i in seq_order]\n",
    "    targets = [targets[i] for i in seq_order]\n",
    "    return inputs,targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model that takes packed sequences in training\n",
    "class PackedLanguageModel(nn.Module):\n",
    "    \n",
    "    def __init__(self,vocab_size,embed_size,hidden_size, nlayers, stop):\n",
    "        super(PackedLanguageModel,self).__init__()\n",
    "        self.vocab_size=vocab_size\n",
    "        self.embed_size = embed_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.nlayers=nlayers\n",
    "        self.embedding = nn.Embedding(vocab_size,embed_size)\n",
    "        self.rnn = nn.LSTM(input_size = embed_size,hidden_size=hidden_size,num_layers=nlayers) # 1 layer, batch_size = False\n",
    "        self.scoring = nn.Linear(hidden_size,vocab_size)\n",
    "        self.stop = stop # stop line character (\\n)\n",
    "    \n",
    "    def forward(self,seq_list): # list\n",
    "        batch_size = len(seq_list)\n",
    "        lens = [len(s) for s in seq_list] # lens of all lines (already sorted)\n",
    "        bounds = [0]\n",
    "        for l in lens:\n",
    "            bounds.append(bounds[-1]+l) # bounds of all lines in the concatenated sequence\n",
    "        seq_concat = torch.cat(seq_list) # concatenated sequence\n",
    "        embed_concat = self.embedding(seq_concat) # concatenated embeddings\n",
    "        embed_list = [embed_concat[bounds[i]:bounds[i+1]] for i in range(batch_size)] # embeddings per line\n",
    "        packed_input = rnn.pack_sequence(embed_list) # packed version\n",
    "        hidden = None\n",
    "        output_packed,hidden = self.rnn(packed_input,hidden)\n",
    "        output_padded, _ = rnn.pad_packed_sequence(output_packed) # unpacked output (padded)\n",
    "        output_flatten = torch.cat([output_padded[:lens[i],i] for i in range(batch_size)]) # concatenated output\n",
    "        scores_flatten = self.scoring(output_flatten) # concatenated logits\n",
    "        return scores_flatten # return concatenated logits\n",
    "    \n",
    "    def generate(self,seq, n_words): # L x V\n",
    "        generated_words = []\n",
    "        embed = self.embedding(seq).unsqueeze(1) # L x 1 x E\n",
    "        hidden = None\n",
    "        output_lstm, hidden = self.rnn(embed,hidden) # L x 1 x H\n",
    "        output = output_lstm[-1] # 1 x H\n",
    "        scores = self.scoring(output) # 1 x V\n",
    "        _,current_word = torch.max(scores,dim=1) # 1 x 1\n",
    "        generated_words.append(current_word)\n",
    "        if n_words > 1:\n",
    "            for i in range(n_words-1):\n",
    "                embed = self.embedding(current_word).unsqueeze(0) # 1 x 1 x E\n",
    "                output_lstm, hidden = self.rnn(embed,hidden) # 1 x 1 x H\n",
    "                output = output_lstm[0] # 1 x H\n",
    "                scores = self.scoring(output) # V\n",
    "                _,current_word = torch.max(scores,dim=1) # 1\n",
    "                generated_words.append(current_word)\n",
    "                if current_word[0].item()==self.stop: # If end of line\n",
    "                    break\n",
    "        return torch.cat(generated_words,dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch_packed(model, optimizer, train_loader, val_loader):\n",
    "    criterion = nn.CrossEntropyLoss(reduction=\"sum\") # sum instead of averaging, to take into account the different lengths\n",
    "    criterion = criterion.to(DEVICE)\n",
    "    batch_id=0\n",
    "    for inputs,targets in train_loader: # lists, presorted, preloaded on GPU\n",
    "        batch_id+=1\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs,torch.cat(targets)) # criterion of the concatenated output\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_id % 100 == 0:\n",
    "            nwords = np.sum(np.array([len(l) for l in inputs]))\n",
    "            lpw = loss.item() / nwords\n",
    "            print(\"At batch\",batch_id)\n",
    "            print(\"Training loss per word:\",lpw)\n",
    "            print(\"Training perplexity :\",np.exp(lpw))\n",
    "    \n",
    "    val_loss = 0\n",
    "    batch_id=0\n",
    "    nwords = 0\n",
    "    for inputs,targets in val_loader:\n",
    "        nwords += np.sum(np.array([len(l) for l in inputs]))\n",
    "        batch_id+=1\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs,torch.cat(targets))\n",
    "        val_loss+=loss.item()\n",
    "    val_lpw = val_loss / nwords\n",
    "    print(\"\\nValidation loss per word:\",val_lpw)\n",
    "    print(\"Validation perplexity :\",np.exp(val_lpw),\"\\n\")\n",
    "    return val_lpw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = PackedLanguageModel(charcount,256,256,3, stop=stop_character)\n",
    "model = model.to(DEVICE)\n",
    "optimizer = torch.optim.Adam(model.parameters(),lr=0.001, weight_decay=1e-6)\n",
    "split = 100000\n",
    "train_dataset = LinesDataset(shakespeare_lines[:split])\n",
    "val_dataset = LinesDataset(shakespeare_lines[split:])\n",
    "train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, collate_fn = collate_lines)\n",
    "val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, collate_fn = collate_lines, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "At batch 100\n",
      "Training loss per word: 2.69172636927\n",
      "Training perplexity : 14.7571301984\n",
      "At batch 200\n",
      "Training loss per word: 2.14736468192\n",
      "Training perplexity : 8.56226434899\n",
      "At batch 300\n",
      "Training loss per word: 1.90214528113\n",
      "Training perplexity : 6.70025296147\n",
      "At batch 400\n",
      "Training loss per word: 1.76002473874\n",
      "Training perplexity : 5.81258118857\n",
      "At batch 500\n",
      "Training loss per word: 1.73988200182\n",
      "Training perplexity : 5.69667118618\n",
      "At batch 600\n",
      "Training loss per word: 1.72483270255\n",
      "Training perplexity : 5.6115821478\n",
      "At batch 700\n",
      "Training loss per word: 1.63099375979\n",
      "Training perplexity : 5.108949265\n",
      "At batch 800\n",
      "Training loss per word: 1.62051262205\n",
      "Training perplexity : 5.05568130615\n",
      "At batch 900\n",
      "Training loss per word: 1.63384090512\n",
      "Training perplexity : 5.12351591289\n",
      "At batch 1000\n",
      "Training loss per word: 1.61838744287\n",
      "Training perplexity : 5.04494848608\n",
      "At batch 1100\n",
      "Training loss per word: 1.49463654559\n",
      "Training perplexity : 4.45771608182\n",
      "At batch 1200\n",
      "Training loss per word: 1.48435207527\n",
      "Training perplexity : 4.412105774\n",
      "At batch 1300\n",
      "Training loss per word: 1.5630041756\n",
      "Training perplexity : 4.77313907567\n",
      "At batch 1400\n",
      "Training loss per word: 1.48977922348\n",
      "Training perplexity : 4.43611602062\n",
      "At batch 1500\n",
      "Training loss per word: 1.45676226971\n",
      "Training perplexity : 4.29204053788\n",
      "\n",
      "Validation loss per word: 1.54778022547\n",
      "Validation perplexity : 4.70102338013 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(1):\n",
    "    train_epoch_packed(model, optimizer, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "uarrer\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(generate(model, \"To be, or not to be, that is the q\",20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the stand\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(generate(model, \"Richard \", 1000))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (dl)",
   "language": "python",
   "name": "dl"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
